#!/bin/bash -e
#
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

tests_dir="@PYTHON_TEST_DIR@"
python_versions="@PYTHON_TEST_VERSIONS@"

# Check that the required python packages are importable --------
pip_depends="numba torch"

# Commands the user must run to satisfy missing dependencies, one per entry.
declare -a install_commands

for py_ver in $python_versions; do
    # Probe each interpreter quietly; only the exit status matters. A failed
    # import (or a missing interpreter) queues the matching pip command.
    python$py_ver -c "import numba, torch" &> /dev/null \
        || install_commands+=("sudo python$py_ver -m pip install $pip_depends")
done

# Prints the apt-style version suffix (dots replaced by dashes, e.g. "11-7") of
# the highest-versioned cuda-* directory inside $1 (default: /usr/local).
# Returns non-zero and prints nothing when no such directory exists.
function latest_cuda_apt_suffix()
{
    local root=${1:-/usr/local}
    local dir newest
    local -a versions=()

    # Glob instead of parsing ls output; an unmatched pattern stays literal and
    # is dropped by the -d test.
    for dir in "$root"/cuda-*; do
        [[ -d "$dir" ]] || continue
        versions+=("${dir##*/cuda-}")
    done

    (( ${#versions[@]} > 0 )) || return 1

    newest=$(printf '%s\n' "${versions[@]}" | sort -rV | head -n 1)

    # NVIDIA's apt repository names versioned packages with dashes
    # (cuda-nvcc-11-7), while the toolkit directory uses dots (cuda-11.7).
    printf '%s\n' "${newest//./-}"
}

# numba needs nvcc
if ! python3 -c "import numba; from numba import cuda; assert cuda.is_available()" > /dev/null 2>&1; then
    if cuda_ver=$(latest_cuda_apt_suffix); then
        install_commands+=("sudo apt-get install cuda-nvcc-$cuda_ver")
    else
        # No toolkit directory to derive the version from: emit a template
        # instead of a truncated package name.
        install_commands+=("sudo apt-get install cuda-nvcc-X-Y  # replace X-Y with your CUDA version (no /usr/local/cuda-* found)")
    fi
fi

# Prints every queued command in install_commands (one per line) under a short
# header and exits the script with status 1. Does nothing when the queue is
# empty.
function report_missing_dependencies()
{
    (( ${#install_commands[@]} > 0 )) || return 0

    # ${0##*/} yields the script's base name without word-splitting a path that
    # contains whitespace; printf '%s' prints the commands verbatim, with no
    # backslash interpretation.
    printf 'Please run the following commands before running %s: \n' "${0##*/}"
    printf '%s\n' "${install_commands[@]}"
    exit 1
}

report_missing_dependencies

# Run tests --------

# Scratch directory handed to pytest as its cache_dir; removed again on exit.
tmpdir=$(mktemp -d)

# Deletes the scratch directory. Installed as the EXIT trap so it runs on
# every exit path of the script.
function on_exit()
{
    # Quoted so that a path containing whitespace or glob characters (e.g. via
    # TMPDIR) cannot make rm touch anything else; '--' stops option parsing and
    # ':?' refuses to run rm with an empty path.
    rm -rf -- "${tmpdir:?}"
}
trap 'on_exit' EXIT

# Module search path given to every interpreter: the caller's PYTHONPATH with
# the NVCV bindings directory appended.
nvcv_pythonpath="$PYTHONPATH:@PYTHON_MODULE_DIR@"

for py_ver in $python_versions; do

    case "$NVCV_FORCE_PYTHON" in
        1|yes)
            # Forced: run the tests without probing for the bindings first.
            ;;
        *)
            # Skip interpreters that cannot import the nvcv module.
            if ! PYTHONPATH="$nvcv_pythonpath" python$py_ver -c 'import nvcv'; then
                echo "Skipping python-$py_ver, NVCV python bindings not installed"
                continue
            fi
            ;;
    esac

    # Extra command-line arguments are forwarded to pytest ahead of the tests
    # directory; the cache goes to the scratch directory.
    PYTHONPATH="$nvcv_pythonpath" NVCV_VERSION="@NVCV_VERSION_FULL@" \
        python$py_ver -m pytest -o cache_dir="$tmpdir" "$@" "$tests_dir"
done
